import os
import torch


class JustReshape(torch.nn.Module):
    """Minimal module whose only operation is a ``view`` of a 4-D tensor.

    The target shape lists the original last two sizes in swapped order, but
    the element order in memory is left untouched: this is a reinterpretation
    (reshape), not a transpose.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` viewed as ``(d0, d1, d3, d2)`` given ``x.shape == (d0, d1, d2, d3)``.

        ``Tensor.view`` only succeeds when ``x``'s memory layout is compatible
        with the requested shape (e.g. a contiguous tensor).
        """
        dims = x.shape
        # Same sizes as the input, with the last two swapped; data is not moved.
        target = (dims[0], dims[1], dims[3], dims[2])
        return x.view(target)

# Export JustReshape to an ONNX file under ./model (relative to the current
# working directory), tracing it with a fixed (2, 3, 4, 5) example input.
work_dir = os.getcwd()
model_dir = os.path.join(work_dir, "model")
# torch.onnx.export does not create missing parent directories, so ensure the
# output folder exists first; exist_ok makes this a no-op when it is present.
os.makedirs(model_dir, exist_ok=True)

net = JustReshape()
model_name = 'just_reshape.onnx'
onnx_path = os.path.join(model_dir, model_name)
# Example input used for tracing. No dynamic_axes are passed, so the exported
# model's input/output shapes are fixed to this example's shapes.
dummy_input = torch.randn(2, 3, 4, 5)
torch.onnx.export(net, dummy_input, onnx_path,
                  input_names=['input'],
                  output_names=['output'],
                  opset_version=11)
print("torch export onnx model ok.onnx model path:%s" % onnx_path)